import os
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms

def _shared_class_index(dirs):
    """Return ``{class_name: id}`` over the union of class folders found in *dirs*.

    Ids follow sorted folder-name order, the same ordering ``ImageFolder``
    applies within a single root, so a lone directory keeps its usual ids.
    """
    names = set()
    for d in dirs:
        with os.scandir(d) as entries:
            names.update(entry.name for entry in entries if entry.is_dir())
    return {name: idx for idx, name in enumerate(sorted(names))}


def _image_folder(root, transform, class_to_idx):
    """Load ``ImageFolder(root)`` and relabel its samples using *class_to_idx*."""
    dataset = datasets.ImageFolder(root=root, transform=transform)
    # ImageFolder numbers only the classes present under *root*, starting at 0.
    # Translate those local ids to the shared ids so one class gets the same
    # label in every directory (split training dirs and the validation dir).
    local_to_shared = [class_to_idx[name] for name in dataset.classes]
    dataset.samples = [(path, local_to_shared[t]) for path, t in dataset.samples]
    dataset.imgs = dataset.samples  # ImageFolder keeps `imgs` as an alias of `samples`
    dataset.targets = [t for _, t in dataset.samples]
    dataset.classes = sorted(class_to_idx, key=class_to_idx.__getitem__)
    dataset.class_to_idx = dict(class_to_idx)
    return dataset


def _build_transforms():
    """Return ``(train_transform, val_transform)`` producing 224x224 tensors in [-1, 1]."""
    # A per-channel mean/std of 0.5 maps ToTensor's [0, 1] output to [-1, 1].
    # ImageNet mean/std was used here previously and was noted as a poor fit
    # for ImageNet-100.
    normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3)),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=10),  # tune num_ops/magnitude as desired
        transforms.ToTensor(),
        normalize,
        # NOTE: RandomErasing(p=0.25) after normalize is currently disabled.
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),  # resize the shorter side to 256, then center-crop
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, val_transform


def get_dataset(train_dirs, val_dir, batch_size=64, num_workers=8, shuffle_train=True, shuffle_val=False):
    """Build ImageFolder-based training and validation DataLoaders.

    Args:
        train_dirs: One directory, or a sequence of directories, in ImageFolder
            layout (``root/<class_name>/<image>``). Several directories are
            concatenated into a single training set.
        val_dir: Validation directory in the same layout.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes for each loader.
        shuffle_train: Whether the training loader shuffles.
        shuffle_val: Whether the validation loader shuffles.

    Returns:
        ``(train_loader, val_loader)``. The training loader drops the last
        incomplete batch. Labels in both loaders come from one shared index:
        the sorted union of class-folder names across every training directory
        and ``val_dir``. A class therefore has the same id everywhere, even when
        each training directory holds a different subset of classes.

    Raises:
        ValueError: If ``train_dirs`` is empty.
        FileNotFoundError: If any training directory or ``val_dir`` is missing.
    """
    # Accept a single path too; iterating a bare string would otherwise walk it
    # character by character (and "/" is a valid directory).
    if isinstance(train_dirs, (str, os.PathLike)):
        train_dirs = [train_dirs]
    train_dirs = list(train_dirs)
    if not train_dirs:
        raise ValueError("train_dirs must contain at least one directory")

    # Validate every path before any (slow) directory scanning starts.
    for d in train_dirs:
        if not os.path.isdir(d):
            raise FileNotFoundError(f"Training directory not found: {d}")
    if not os.path.isdir(val_dir):
        raise FileNotFoundError(f"Validation directory not found: {val_dir}")

    train_transform, val_transform = _build_transforms()
    class_to_idx = _shared_class_index(train_dirs + [val_dir])

    train_sets = [_image_folder(d, train_transform, class_to_idx) for d in train_dirs]
    train_dataset = train_sets[0] if len(train_sets) == 1 else ConcatDataset(train_sets)
    val_dataset = _image_folder(val_dir, val_transform, class_to_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train,
                              num_workers=num_workers, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle_val,
                            num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader



# Legacy TensorFlow 2.16 implementation of get_dataset, apparently kept for reference.
# It is wrapped in a bare triple-quoted string, so none of it is imported or executed.
# Removing the quotes as-is would redefine get_dataset and shadow the PyTorch version above.
# NOTE(review) if this is ever revived:
#   - label_names is built from train_dirs only. A validation class folder absent
#     from training would map to the lookup table's default id of -1.
#   - ignore_errors() is applied after batch() and repeat(). One unreadable image
#     therefore appears to discard its whole batch, not just that image.
#   - The StaticHashTable is constructed inside process_image instead of once outside it.
'''
# DataSet Loading for TensorFlow V 2.16
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import preprocess_input
import os
import glob

def get_dataset(train_dirs, val_dir, batch_size=64, shuffle= 5000, pref= 2, shuffle_val=False):
    # List all four part
    val_files = tf.data.Dataset.list_files(os.path.join(val_dir, "*/*.JPEG"), shuffle=False) #.take(100)
    # Use all subfolders as one dataset
    train_ds = tf.data.Dataset.list_files(
        [os.path.join(d, "*/*.JPEG") for d in train_dirs],
        shuffle=True
    )
    IMG_SIZE = 224  # For ImageNet compatibility
    def process_image(file_path):
        # Extract label from the directory name
        parts = tf.strings.split(file_path, os.sep)
        label = parts[-2]  # 'n01440764' for example
        # Convert label to integer using a lookup table
        table = tf.lookup.StaticHashTable(
            initializer=tf.lookup.KeyValueTensorInitializer(
                keys=tf.constant(label_names),  # Defined below
                values=tf.constant(list(range(len(label_names))))
            ),
            default_value=-1
        )
        label_id = table.lookup(label)
        image = tf.io.read_file(file_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
        image = preprocess_input(image)
    #    image = tf.cast(image, tf.float16) / 255.0  # Normalize
        return image, label_id
    all_dirs = train_dirs
    label_names = sorted(set(
        folder
        for d in all_dirs
        for folder in os.listdir(d)
        if os.path.isdir(os.path.join(d, folder))
    ))
    
    #AUTOTUNE = tf.data.AUTOTUNE
    train_ds = train_ds.map(process_image, num_parallel_calls=4) # replaced 4 with AUTOTUNE
    train_ds = train_ds.shuffle(shuffle)
    train_ds = train_ds.batch(batch_size).repeat()
    train_ds = train_ds.prefetch(pref)
    train_ds = train_ds.apply(tf.data.experimental.ignore_errors())
    options = tf.data.Options()
    options.experimental_threading.max_intra_op_parallelism = 1
    options.experimental_threading.private_threadpool_size = 4
    train_ds = train_ds.with_options(options)
    val_ds = val_files.map(process_image, num_parallel_calls=4)
    val_ds = val_ds.batch(batch_size).prefetch(pref)
    return train_ds, val_ds'''